Polynomial Regression

Linear Basis Function Models

An ordinary linear regression model satisfies the following equation.

$$ y = \sum_{j=1}^{D} w_j x_j = w^T x $$

Here the weight vector $w$ has the same dimension as the independent variable: if $x \in \mathbf{R}^D$, then $w \in \mathbf{R}^D$.

A linear basis function model does not take a linear combination of $x$ directly; instead, it treats the values produced by passing $x$ through basis functions as new independent variables and applies a linear regression model to them. The dimension of the weight vector can therefore differ from that of $x$, depending on the basis functions: if $\phi(\cdot): \mathbf{R}^D \rightarrow \mathbf{R}^M$, then $w \in \mathbf{R}^M$.

$$ y = \sum_{j=1}^{M} w_j \phi_j(x) = w^T \phi(x) $$
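For example, with $D = 1$ and $M$ Gaussian bumps as the basis, the fit is still ordinary least squares, just on the transformed design matrix. A minimal sketch (the gaussian_basis helper, its centers, and its width are illustrative choices, not fixed by the model):

import numpy as np

def gaussian_basis(x, centers, s=0.1):
    # phi: R -> R^M, one column per basis function exp(-(x - c)^2 / 2s^2)
    return np.exp(-(x[:, None] - centers[None, :]) ** 2 / (2 * s ** 2))

rng = np.random.default_rng(0)
x = np.sort(rng.random(30))
y = np.cos(1.5 * np.pi * x) + 0.1 * rng.standard_normal(30)

centers = np.linspace(0, 1, 8)      # M = 8 basis functions
Phi = gaussian_basis(x, centers)    # design matrix, shape (30, 8)
w, *_ = np.linalg.lstsq(Phi, y, rcond=None)
print(w.shape)                      # (8,): w lives in R^M, not R^D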

Polynomial Regression

Polynomial regression is a linear basis function model that uses the following polynomial functions as its basis.

$$ 1, x, x^2, \ldots, x^M $$

The relationship between the dependent and independent variables can then be written as follows.

$$ y = w_0 + w_1 x + w_2 x^2 + \cdots + w_M x^M $$
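Writing this out for every sample gives a matrix of powers of $x$ (a Vandermonde matrix), and the weights come from ordinary least squares on that matrix. A quick sketch with numpy (illustrative data, using np.vander to build the power columns):

import numpy as np

x = np.array([0.0, 0.25, 0.5, 0.75, 1.0])
y = np.cos(1.5 * np.pi * x)

# columns 1, x, x^2, x^3 (increasing=True puts the constant first)
X = np.vander(x, N=4, increasing=True)
w, *_ = np.linalg.lstsq(X, y, rcond=None)
print(w)  # w_0, w_1, w_2, w_3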

Polynomial Regression with statsmodels

In statsmodels, polynomial regression can be done with the from_formula method of the OLS class.
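Note that the formulas below write I(x**2) rather than x**2: inside a patsy formula the ** operator has its own meaning (interaction powers), and I() forces it to be evaluated as plain Python arithmetic. A small self-contained sketch of the design matrix patsy builds:

import pandas as pd
from patsy import dmatrix

d = pd.DataFrame({"x": [1.0, 2.0, 3.0]})
# columns: Intercept, x, I(x ** 2)
print(dmatrix("x + I(x**2)", d))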


In [32]:
import numpy as np
import pandas as pd
import statsmodels.api as sm

# noisy samples from a cosine curve
np.random.seed(0)
n_samples = 30
X = np.sort(np.random.rand(n_samples))
y = np.cos(1.5 * np.pi * X) + np.random.randn(n_samples) * 0.1

dfX = pd.DataFrame(X, columns=["x"])
dfX = sm.add_constant(dfX)  # not strictly needed: from_formula adds its own intercept
dfy = pd.DataFrame(y, columns=["y"])
df = pd.concat([dfX, dfy], axis=1)

print(sm.OLS.from_formula("y ~ x", data=df).fit().summary())
print(sm.OLS.from_formula("y ~ x + I(x**2)", data=df).fit().summary())
print(sm.OLS.from_formula("y ~ x + I(x**2) + I(x**3)", data=df).fit().summary())
print(sm.OLS.from_formula("y ~ x + I(x**2) + I(x**3) + I(x**4)", data=df).fit().summary())
print(sm.OLS.from_formula("y ~ x + I(x**2) + I(x**3) + I(x**4) + I(x**5)", data=df).fit().summary())


                            OLS Regression Results                            
==============================================================================
Dep. Variable:                      y   R-squared:                       0.474
Model:                            OLS   Adj. R-squared:                  0.455
Method:                 Least Squares   F-statistic:                     25.20
Date:                Sun, 05 Jun 2016   Prob (F-statistic):           2.63e-05
Time:                        01:26:11   Log-Likelihood:                -20.253
No. Observations:                  30   AIC:                             44.51
Df Residuals:                      28   BIC:                             47.31
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
Intercept      0.5367      0.207      2.595      0.015       0.113       0.960
x             -1.6093      0.321     -5.020      0.000      -2.266      -0.953
==============================================================================
Omnibus:                        5.427   Durbin-Watson:                   0.202
Prob(Omnibus):                  0.066   Jarque-Bera (JB):                2.523
Skew:                           0.425   Prob(JB):                        0.283
Kurtosis:                       1.862   Cond. No.                         4.85
==============================================================================

Warnings:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
                            OLS Regression Results                            
==============================================================================
Dep. Variable:                      y   R-squared:                       0.931
Model:                            OLS   Adj. R-squared:                  0.926
Method:                 Least Squares   F-statistic:                     181.6
Date:                Sun, 05 Jun 2016   Prob (F-statistic):           2.19e-16
Time:                        01:26:11   Log-Likelihood:                 10.181
No. Observations:                  30   AIC:                            -14.36
Df Residuals:                      27   BIC:                            -10.16
Df Model:                           2                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
Intercept      1.5411      0.107     14.378      0.000       1.321       1.761
x             -7.3196      0.444    -16.498      0.000      -8.230      -6.409
I(x ** 2)      5.5596      0.416     13.355      0.000       4.705       6.414
==============================================================================
Omnibus:                        3.753   Durbin-Watson:                   0.889
Prob(Omnibus):                  0.153   Jarque-Bera (JB):                1.543
Skew:                           0.062   Prob(JB):                        0.462
Kurtosis:                       1.896   Cond. No.                         23.0
==============================================================================

Warnings:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
                            OLS Regression Results                            
==============================================================================
Dep. Variable:                      y   R-squared:                       0.971
Model:                            OLS   Adj. R-squared:                  0.968
Method:                 Least Squares   F-statistic:                     289.3
Date:                Sun, 05 Jun 2016   Prob (F-statistic):           4.40e-20
Time:                        01:26:11   Log-Likelihood:                 23.183
No. Observations:                  30   AIC:                            -38.37
Df Residuals:                      26   BIC:                            -32.76
Df Model:                           3                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
Intercept      1.1694      0.094     12.418      0.000       0.976       1.363
x             -2.1962      0.904     -2.428      0.022      -4.055      -0.337
I(x ** 2)     -7.0567      2.125     -3.321      0.003     -11.424      -2.689
I(x ** 3)      8.2029      1.370      5.988      0.000       5.387      11.018
==============================================================================
Omnibus:                        2.444   Durbin-Watson:                   1.982
Prob(Omnibus):                  0.295   Jarque-Bera (JB):                1.471
Skew:                           0.530   Prob(JB):                        0.479
Kurtosis:                       3.228   Cond. No.                         160.
==============================================================================

Warnings:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
                            OLS Regression Results                            
==============================================================================
Dep. Variable:                      y   R-squared:                       0.973
Model:                            OLS   Adj. R-squared:                  0.969
Method:                 Least Squares   F-statistic:                     225.7
Date:                Sun, 05 Jun 2016   Prob (F-statistic):           3.17e-19
Time:                        01:26:11   Log-Likelihood:                 24.329
No. Observations:                  30   AIC:                            -38.66
Df Residuals:                      25   BIC:                            -31.65
Df Model:                           4                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
Intercept      1.0311      0.135      7.645      0.000       0.753       1.309
x              0.4675      2.089      0.224      0.825      -3.835       4.770
I(x ** 2)    -17.7895      7.900     -2.252      0.033     -34.060      -1.519
I(x ** 3)     23.5927     11.008      2.143      0.042       0.921      46.265
I(x ** 4)     -7.2629      5.156     -1.409      0.171     -17.882       3.357
==============================================================================
Omnibus:                        0.929   Durbin-Watson:                   2.192
Prob(Omnibus):                  0.628   Jarque-Bera (JB):                0.696
Skew:                           0.362   Prob(JB):                        0.706
Kurtosis:                       2.817   Cond. No.                         914.
==============================================================================

Warnings:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
                            OLS Regression Results                            
==============================================================================
Dep. Variable:                      y   R-squared:                       0.974
Model:                            OLS   Adj. R-squared:                  0.969
Method:                 Least Squares   F-statistic:                     182.9
Date:                Sun, 05 Jun 2016   Prob (F-statistic):           2.73e-18
Time:                        01:26:11   Log-Likelihood:                 25.117
No. Observations:                  30   AIC:                            -38.23
Df Residuals:                      24   BIC:                            -29.83
Df Model:                           5                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
Intercept      1.1269      0.158      7.118      0.000       0.800       1.454
x             -2.1789      3.118     -0.699      0.491      -8.615       4.257
I(x ** 2)      0.8313     18.154      0.046      0.964     -36.637      38.299
I(x ** 3)    -26.1798     45.097     -0.581      0.567    -119.255      66.895
I(x ** 4)     48.6674     49.427      0.985      0.335     -53.346     150.681
I(x ** 5)    -22.3170     19.616     -1.138      0.266     -62.802      18.168
==============================================================================
Omnibus:                        1.531   Durbin-Watson:                   2.272
Prob(Omnibus):                  0.465   Jarque-Bera (JB):                1.042
Skew:                           0.455   Prob(JB):                        0.594
Kurtosis:                       2.936   Cond. No.                     4.62e+03
==============================================================================

Warnings:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
[2] The condition number is large, 4.62e+03. This might indicate that there are
strong multicollinearity or other numerical problems.

Polynomial Regression with Scikit-Learn

In Scikit-Learn, polynomial regression can be done with the PolynomialFeatures class from the preprocessing subpackage.
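Before wiring it into a pipeline, it helps to see what PolynomialFeatures produces on its own: each input value expands into its powers $1, x, x^2, \ldots$. A small sketch:

import numpy as np
from sklearn.preprocessing import PolynomialFeatures

X = np.array([[2.0], [3.0]])
# each row becomes [1, x, x^2, x^3]
print(PolynomialFeatures(degree=3).fit_transform(X))
# [[ 1.  2.  4.  8.]
#  [ 1.  3.  9. 27.]]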


In [43]:
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline

def polyreg(degree, seed=0, plot=True):
    # expand x into polynomial features, then fit by ordinary least squares
    polynomial_features = PolynomialFeatures(degree=degree)
    linear_regression = LinearRegression()
    model = Pipeline([("polynomial_features", polynomial_features),
                      ("linear_regression", linear_regression)])

    np.random.seed(seed)
    n_samples = 30
    X = np.sort(np.random.rand(n_samples))
    y = np.cos(1.5 * np.pi * X) + np.random.randn(n_samples) * 0.1
    X = X[:, np.newaxis]

    model.fit(X, y)
    
    if plot:
        plt.scatter(X, y)
        xx = np.linspace(0, 1, 1000)
        plt.plot(xx, model.predict(xx[:, np.newaxis]))
        plt.ylim(-2, 2)
        plt.show()
    
    # PolynomialFeatures adds a bias column whose coefficient is redundant
    # with intercept_, so coef_[0] comes out as 0 except under severe
    # numerical problems (see the degree-50 result below)
    reg = model.named_steps["linear_regression"]
    return reg.coef_, reg.intercept_

In [44]:
polyreg(1)


Out[44]:
(array([ 0.        , -1.60931179]), 0.5366803303178731)

In [45]:
polyreg(2)


Out[45]:
(array([ 0.        , -7.31956683,  5.55955392]), 1.5411486872392566)

In [46]:
polyreg(3)


Out[46]:
(array([ 0.        , -2.19617614, -7.05669992,  8.202858  ]),
 1.1693941158316374)

In [47]:
polyreg(4)


Out[47]:
(array([  0.        ,   0.46754142, -17.78954475,  23.5926603 ,  -7.26289872]),
 1.0310751547291352)

In [48]:
polyreg(5)


Out[48]:
(array([  0.        ,  -2.17893409,   0.83127959, -26.17982339,
         48.6673596 , -22.31700272]), 1.1268720279438245)

In [59]:
polyreg(50)


Out[59]:
(array([ -3.54077870e+11,  -6.82363018e+05,   2.79580683e+07,
         -6.10015742e+08,   8.13342580e+09,  -7.10336646e+10,
          4.21563463e+11,  -1.72421906e+12,   4.80944108e+12,
         -8.67339556e+12,   8.40559920e+12,  -2.25138987e+11,
         -7.73102470e+12,   2.59082210e+12,   6.73888544e+12,
         -6.22777261e+11,  -7.19644198e+12,  -2.61772963e+12,
          5.24830476e+12,   5.89152381e+12,   3.33510205e+11,
         -5.57369036e+12,  -5.60583249e+12,  -7.45768258e+11,
          4.49445663e+12,   6.68123130e+12,   2.73265925e+12,
         -2.53228288e+12,  -6.92293100e+12,  -5.27044666e+12,
         -4.86708454e+11,   4.23739272e+12,   6.29663050e+12,
          4.91118368e+12,   2.32162149e+11,  -4.40488571e+12,
         -7.00682587e+12,  -4.95977402e+12,  -6.80200902e+10,
          4.90664146e+12,   7.30492141e+12,   4.90321297e+12,
         -7.12247605e+11,  -6.29995496e+12,  -7.55464838e+12,
         -3.05026616e+12,   5.16022515e+12,   9.30836324e+12,
          1.81107422e+12,  -1.21255316e+13,   4.75422005e+12]),
 354077875945.89655)

Overfitting

Fitting a model too closely to one particular sample of data is called overfitting.

Overfitting occurs when

  • the number of model parameters is excessive relative to the number of independent-variable data points, or
  • the independent-variable data are not mutually independent.

In these situations, multiple solutions can exist for the same conditions.

Overfitting is a problem for the following reasons.

  • Small changes in the sample produce large changes in the fitted result.
  • The error grows for new independent-variable values not used in training (cross-validation error); the sketch after this list illustrates this.
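A minimal sketch of the second point, assuming the current scikit-learn model_selection API: cross-validated mean squared error on the same data, which deteriorates at high degree even as the training fit keeps improving.

import numpy as np
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.model_selection import cross_val_score

np.random.seed(0)
X = np.sort(np.random.rand(30))[:, np.newaxis]
y = np.cos(1.5 * np.pi * X.ravel()) + np.random.randn(30) * 0.1

for degree in [1, 2, 3, 10, 20]:
    model = Pipeline([("poly", PolynomialFeatures(degree=degree)),
                      ("lr", LinearRegression())])
    mse = -cross_val_score(model, X, y, cv=5,
                           scoring="neg_mean_squared_error").mean()
    print(degree, mse)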

In [53]:
polyreg(2, 0)


Out[53]:
(array([ 0.        , -7.31956683,  5.55955392]), 1.5411486872392566)

In [54]:
polyreg(2, 1)


Out[54]:
(array([ 0.        , -6.23625845,  4.44609448]), 1.4254320912432297)

In [55]:
polyreg(2, 2)


Out[55]:
(array([ 0.        , -6.13954485,  4.0416645 ]), 1.4805367495708419)

In [56]:
polyreg(50, 1)


Out[56]:
(array([  1.03487567e+12,   6.41356614e+02,  -7.37707405e+04,
          3.47009207e+06,  -8.93425334e+07,   1.43225583e+09,
         -1.53479282e+10,   1.14481498e+11,  -6.06493990e+11,
          2.28756049e+12,  -6.02509343e+12,   1.04394093e+13,
         -9.88772860e+12,   5.74067175e+11,   7.93655296e+12,
         -2.11571627e+12,  -7.26374919e+12,   5.40889486e+11,
          6.25632612e+12,   3.13175832e+12,  -3.11863514e+12,
         -5.66954308e+12,  -1.84433963e+12,   2.70144437e+12,
          4.61155286e+12,   3.24644598e+12,  -9.30877381e+11,
         -3.53728583e+12,  -4.05867008e+12,  -2.19023729e+12,
          1.04446889e+12,   3.39688384e+12,   3.75237876e+12,
          2.25965772e+12,  -2.95700567e+11,  -2.20081256e+12,
         -3.49636422e+12,  -2.82580300e+12,  -1.09826744e+12,
          1.04143809e+12,   2.63070982e+12,   3.00813359e+12,
          2.22845559e+12,   1.73623796e+11,  -2.04361932e+12,
         -2.86597160e+12,  -2.32627169e+12,  -1.87521310e+11,
          2.61104198e+12,   3.07195305e+12,  -2.45638829e+12]),
 -1034875673148.4902)

In [57]:
polyreg(50, 2)


Out[57]:
(array([ -7.33596257e+11,   1.98861224e+04,  -9.43248464e+05,
          2.52484807e+07,  -4.31432775e+08,   5.01999215e+09,
         -4.12517490e+10,   2.43930690e+11,  -1.04190874e+12,
          3.17100035e+12,  -6.58648857e+12,   8.29163658e+12,
         -3.66439944e+12,  -4.74652291e+12,   4.94757219e+12,
          3.93591195e+12,  -3.73829162e+12,  -5.00770563e+12,
          5.90836540e+11,   4.98318527e+12,   3.64832419e+12,
         -1.10067321e+12,  -4.51313707e+12,  -4.12027976e+12,
         -8.19672577e+11,   2.77516465e+12,   4.48170959e+12,
          3.66138118e+12,   1.09440707e+12,  -1.80092632e+12,
         -3.77150708e+12,  -4.18490543e+12,  -3.10550266e+12,
         -1.07943119e+12,   1.15978231e+12,   2.95817284e+12,
          3.89999943e+12,   3.85346993e+12,   2.93656430e+12,
          1.43047269e+12,  -3.10675238e+11,  -1.94304300e+12,
         -3.18775017e+12,  -3.85941039e+12,  -3.87198838e+12,
         -3.22847188e+12,  -2.00091265e+12,  -3.06800669e+11,
          1.71279784e+12,   3.91227163e+12,   6.15544154e+12]),
 733596257309.59436)

In [58]:
polyreg(50, 3)


Out[58]:
(array([ -2.88684895e+10,  -1.69265013e+03,   1.07503904e+05,
         -3.73070301e+06,   7.95857714e+07,  -1.11268843e+09,
          1.06171355e+10,  -7.08626024e+10,   3.34107254e+11,
         -1.10429118e+12,   2.46777295e+12,  -3.35398796e+12,
          1.74640022e+12,   1.80319986e+12,  -2.36721052e+12,
         -1.37998701e+12,   1.99150601e+12,   1.93976460e+12,
         -8.94757268e+11,  -2.33146640e+12,  -9.72158268e+11,
          1.31850842e+12,   2.15693997e+12,   1.02333512e+12,
         -8.62327413e+11,  -1.95209435e+12,  -1.59054875e+12,
         -2.28980102e+11,   1.17887364e+12,   1.82278327e+12,
          1.43746654e+12,   3.43226116e+11,  -8.36434363e+11,
         -1.58552994e+12,  -1.58333428e+12,  -9.16967716e+11,
          9.40335778e+10,   1.02019879e+12,   1.51780721e+12,
          1.43869891e+12,   8.23700423e+11,  -8.52329855e+10,
         -9.58891283e+11,  -1.47194298e+12,  -1.42642727e+12,
         -7.96020792e+11,   2.42567098e+11,   1.28693518e+12,
          1.74708183e+12,   9.34712778e+11,  -1.91148853e+12]),
 28868489542.477337)